{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f02757c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['es:au:in:de:uk2jp:it:mx:fr:ca']\n",
      "es:au:in:de:uk2jp:it:mx:fr:ca\n",
      "['c=0_db=1_r=10.0_H=256_bs=128']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "exp_folder = 'experiments'\n",
    "\n",
    "# os.listdir returns names in arbitrary order and includes hidden entries\n",
    "# (e.g. .ipynb_checkpoints), so sort and filter to make index 0 reproducible.\n",
    "exp_main_list = sorted(\n",
    "    name for name in os.listdir(exp_folder) if not name.startswith('.')\n",
    ")\n",
    "print(exp_main_list)\n",
    "\n",
    "# select an experiment\n",
    "exp_main = exp_main_list[0]\n",
    "print(exp_main)\n",
    "\n",
    "# Sub-folders of the selected experiment, listed the same reproducible way.\n",
    "exp_sub_list = sorted(\n",
    "    name\n",
    "    for name in os.listdir(os.path.join(exp_folder, exp_main))\n",
    "    if not name.startswith('.')\n",
    ")\n",
    "print(exp_sub_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b93e33b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "experiments/es:au:in:de:uk2jp:it:mx:fr:ca/c=0_db=1_r=10.0_H=256_bs=128/es:au:in:de:uk2jp:it:mx:fr:ca.pickle\n",
      "{'c=0_db=1_r=10.0_H=256_bs=128': 0.14936458918959689}\n",
      "c=0_db=1_r=10.0_H=256_bs=128   0.14936458918959689\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "def get_metric(results, item_type, metric):\n",
    "    \"\"\"Average one metric over every entry of a results dict.\n",
    "\n",
    "    Args:\n",
    "        results: mapping whose values are nested dicts read as\n",
    "            value[item_type][metric].\n",
    "        item_type: second-level key to read, e.g. 'unseen'.\n",
    "        metric: third-level key to read, e.g. 'rNDCG'.\n",
    "\n",
    "    Returns:\n",
    "        The np.mean of the selected values, or nan when results is empty.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if an entry lacks item_type or metric.\n",
    "    \"\"\"\n",
    "    metric_list = [entry[item_type][metric] for entry in results.values()]\n",
    "    if not metric_list:\n",
    "        # np.mean([]) also gives nan but emits a RuntimeWarning; return it quietly.\n",
    "        return np.nan\n",
    "    return np.mean(metric_list)\n",
    "\n",
    "# Plotting is currently disabled: plt.figure() is commented out together with the\n",
    "# other plotting calls below so this cell no longer renders an empty figure.\n",
    "#plt.figure()\n",
    "\n",
    "# exp_main has the form '<seen>2<test>', each side a ':'-separated list\n",
    "# (presumably domain codes - confirm against the experiment naming scheme).\n",
    "seen, test = exp_main.split('2')\n",
    "seen = seen.split(':')\n",
    "test = test.split(':')\n",
    "\n",
    "ndcg_dict = {}\n",
    "recall_dict = {}\n",
    "for exp_sub in exp_sub_list:\n",
    "    # Run folders are named with five '_'-separated 'name=value' fields,\n",
    "    # e.g. 'c=0_db=1_r=10.0_H=256_bs=128'.\n",
    "    fields = exp_sub.split('_')\n",
    "    pkl_folder = os.path.join(exp_folder, exp_main, exp_sub)\n",
    "    if len(fields) != 5 or not os.path.isdir(pkl_folder):\n",
    "        # Stray entries (e.g. .ipynb_checkpoints, .DS_Store) are not experiment runs.\n",
    "        print('skipping', pkl_folder)\n",
    "        continue\n",
    "    c, db, r, f, bs = (field.split('=')[1] for field in fields)\n",
    "\n",
    "    pkl_file_list = [\n",
    "        name for name in os.listdir(pkl_folder) if name.endswith('.pickle')\n",
    "    ]\n",
    "\n",
    "    for pkl_file in pkl_file_list:\n",
    "        # File names have the form '<train>2<test>.pickle'; [:-7] drops '.pickle'.\n",
    "        train_domain_str = pkl_file.split('2')[0]\n",
    "        test_domain_str = pkl_file.split('2')[1][:-7]\n",
    "\n",
    "        pkl_file = os.path.join(pkl_folder, pkl_file)\n",
    "\n",
    "        # SECURITY: pickle.load can execute arbitrary code; only open trusted files.\n",
    "        with open(pkl_file, 'rb') as handle:\n",
    "            results = pickle.load(handle)\n",
    "        print(pkl_file)\n",
    "\n",
    "        # NOTE(review): keyed by exp_sub only, so if a folder holds several\n",
    "        # pickles the last one processed overwrites the earlier ones.\n",
    "        ndcg_dict[exp_sub] = get_metric(results, 'unseen', 'rNDCG')\n",
    "        recall_dict[exp_sub] = get_metric(results, 'unseen', 'rRecall')\n",
    "print(ndcg_dict)\n",
    "\n",
    "#plt.plot(ndcg_dict.keys(), ndcg_dict.values(), label=exp_sub)\n",
    "#plt.legend()\n",
    "#plt.xticks(rotation=45, ha='right')\n",
    "#plt.show()\n",
    "\n",
    "# One line per run: name padded to 30 columns, then its mean unseen rNDCG.\n",
    "for key, value in ndcg_dict.items():\n",
    "    print(key.ljust(30), value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f469bc4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'jp': {'random': {'count': 0,\n",
       "   'rNDCG': 0.00021033710343167268,\n",
       "   'rRecall': 0.0003983887388783144},\n",
       "  'mean': {'count': 56771,\n",
       "   'rNDCG': 0.15890044206841103,\n",
       "   'rRecall': 0.238126273390757},\n",
       "  'seen': {'count': 53538,\n",
       "   'rNDCG': 0.15354524199600292,\n",
       "   'rRecall': 0.23254510814748403},\n",
       "  'unseen': {'count': 3233,\n",
       "   'rNDCG': 0.24758176018674158,\n",
       "   'rRecall': 0.3305495411898031}},\n",
       " 'it': {'random': {'count': 0,\n",
       "   'rNDCG': 0.00019789900283233015,\n",
       "   'rRecall': 0.0003849762998965376},\n",
       "  'mean': {'count': 175830,\n",
       "   'rNDCG': 0.20550147497500434,\n",
       "   'rRecall': 0.2940269009839049},\n",
       "  'seen': {'count': 174758,\n",
       "   'rNDCG': 0.20566619456640947,\n",
       "   'rRecall': 0.29450296982112406},\n",
       "  'unseen': {'count': 1072,\n",
       "   'rNDCG': 0.17864880113657866,\n",
       "   'rRecall': 0.216417910447741}},\n",
       " 'mx': {'random': {'count': 0,\n",
       "   'rNDCG': 0.00020081276721162385,\n",
       "   'rRecall': 0.0003944773175542406},\n",
       "  'mean': {'count': 70106,\n",
       "   'rNDCG': 0.09412937566199892,\n",
       "   'rRecall': 0.13449395677022413},\n",
       "  'seen': {'count': 55807,\n",
       "   'rNDCG': 0.09632631051605038,\n",
       "   'rRecall': 0.13757234755496622},\n",
       "  'unseen': {'count': 14299,\n",
       "   'rNDCG': 0.08555504575081223,\n",
       "   'rRecall': 0.12247942746578928}},\n",
       " 'fr': {'random': {'count': 0,\n",
       "   'rNDCG': 0.00019843353188543275,\n",
       "   'rRecall': 0.000390360158609496},\n",
       "  'mean': {'count': 133317,\n",
       "   'rNDCG': 0.19918922998500865,\n",
       "   'rRecall': 0.28165575282972166},\n",
       "  'seen': {'count': 132142,\n",
       "   'rNDCG': 0.19986115657545447,\n",
       "   'rRecall': 0.2827753477319853},\n",
       "  'unseen': {'count': 1175,\n",
       "   'rNDCG': 0.12362350784483539,\n",
       "   'rRecall': 0.15574468085105056}},\n",
       " 'ca': {'random': {'count': 0,\n",
       "   'rNDCG': 0.00019538200592866706,\n",
       "   'rRecall': 0.0003924488810175495},\n",
       "  'mean': {'count': 293281,\n",
       "   'rNDCG': 0.09403399906484175,\n",
       "   'rRecall': 0.13356473825443857},\n",
       "  'seen': {'count': 219654,\n",
       "   'rNDCG': 0.0882083601598989,\n",
       "   'rRecall': 0.12435921949975871},\n",
       "  'unseen': {'count': 73627,\n",
       "   'rNDCG': 0.11141383102901666,\n",
       "   'rRecall': 0.1610278837926302}}}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the `results` object bound by the most recent pickle.load in the loop above.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a49b94f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'jp:it:mx:fr:ca:us'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display test_domain_str: the second '2'-separated segment of the last processed\n",
    "# pickle file name, with its final 7 characters removed.\n",
    "test_domain_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "e31ab72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo: sort three parallel lists by the first one, then split them apart again.\n",
    "l1 = [4,5,1,3,2]\n",
    "l2 = ['h','e','l','l','o']\n",
    "l3 = ['p','l','a','n','e']\n",
    "ls = zip(l1, l2, l3)  # lazy iterator; not consumed in this cell\n",
    "\n",
    "# sorted() orders the zipped tuples by their first element (later elements break ties).\n",
    "s_ls = sorted(zip(l1, l2, l3))\n",
    "# The expected value used to sit here as a bare list literal (a no-op); assert it instead.\n",
    "assert s_ls == [(1, 'l', 'a'), (2, 'o', 'e'), (3, 'l', 'n'), (4, 'h', 'p'), (5, 'e', 'l')]\n",
    "\n",
    "# zip(*...) transposes the sorted tuples back into one tuple per original list.\n",
    "z_s_ls = zip(*s_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "134c46bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 'l', 'a'), (2, 'o', 'e'), (3, 'l', 'n'), (4, 'h', 'p'), (5, 'e', 'l')]\n"
     ]
    }
   ],
   "source": [
    "# Print s_ls, the sorted list of (l1, l2, l3) tuples.\n",
    "print(s_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "bc5f1469",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<zip object at 0x7fd4534ce9b0>\n",
      "(1, 2, 3, 4, 5)\n",
      "('l', 'o', 'l', 'h', 'e')\n",
      "('a', 'e', 'n', 'p', 'l')\n"
     ]
    }
   ],
   "source": [
    "# Transpose the sorted tuples: one output tuple per original list.\n",
    "z_s_ls = zip(*s_ls)\n",
    "# Printing the zip object shows only its repr; the contents appear when iterating.\n",
    "print(z_s_ls)\n",
    "for column in z_s_ls:\n",
    "    print(column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7c7e2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_amazonei_pytorch_latest_p37",
   "language": "python",
   "name": "conda_amazonei_pytorch_latest_p37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
